import dgl
import torch
import torch.nn.functional as F
import numpy
import argparse
import time
import numpy as np
import os
import logging
from dataset import Dataset
from sklearn.metrics import f1_score, accuracy_score, recall_score, roc_auc_score, precision_score, confusion_matrix
from BWGNN import *
from sklearn.model_selection import train_test_split

# Seed every random number generator used by this script with the given value.
def set_seed(seed):
    """Seed the NumPy, PyTorch (CPU and all CUDA devices) and DGL generators.

    Args:
        seed: integer seed handed unchanged to each library's seeding function.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        dgl.random.seed,
    )
    for seeder in seeders:
        seeder(seed)
    
# Geometric mean (G-mean) of the two per-class recalls.
def geometric_mean(recall_0, recall_1):
    """Return the geometric mean (G-mean) of two recall values.

    Args:
        recall_0: recall of class 0.
        recall_1: recall of class 1.

    Returns:
        The square root of ``recall_0 * recall_1`` as computed by ``np.sqrt``.
    """
    product = recall_0 * recall_1
    return np.sqrt(product)


def _ensure_console_logging():
    """Attach one INFO-level console handler to the root logger, at most once.

    ``train`` may be called several times in one process; adding a new handler
    on every call would print each log record once per call.
    """
    root = logging.getLogger('')
    for handler in root.handlers:
        # FileHandler subclasses StreamHandler, so it has to be excluded explicitly.
        if isinstance(handler, logging.StreamHandler) and not isinstance(handler, logging.FileHandler):
            return
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    root.addHandler(console)


def _indices_to_mask(indices, size):
    """Return a boolean tensor of length ``size`` that is True exactly at ``indices``."""
    mask = torch.zeros(size, dtype=torch.bool)
    mask[torch.as_tensor(indices, dtype=torch.long)] = True
    return mask


def _per_class_accuracy(y_true, y_pred):
    """Return ``(acc1, acc0)``: the fraction of class-1 / class-0 samples predicted correctly.

    Accuracy restricted to the samples of one class is the same quantity as
    that class's recall, so these two values are also the inputs of the G-mean.
    A class with no samples yields 0.
    """
    pos = y_true == 1
    neg = y_true == 0
    acc1 = float(np.mean(y_pred[pos] == 1)) if pos.any() else 0
    acc0 = float(np.mean(y_pred[neg] == 0)) if neg.any() else 0
    return acc1, acc0


def train(model, g, args):
    """Train ``model`` on graph ``g`` and report the test metrics of the best validation epoch.

    Nodes whose label is 2 are excluded from every split; for the ``amazon``
    dataset only node ids >= 3305 are used. The remaining nodes are split with
    stratification into train (``args.train_ratio``) and, from the rest,
    validation/test (33% / 67%). Each epoch the decision threshold is tuned on
    the validation set for macro-F1, and the test metrics of the epoch with the
    highest validation macro-F1 are kept. Records go to
    ``logs/train_log_<dataset>_seed<seed>.txt`` and to the console.

    Args:
        model: module called as ``model(features, 0)``; the first element of
            its result is used as per-node log-probabilities over two classes
            (it is fed to ``F.nll_loss`` and ``torch.exp``).
        g: graph exposing ``g.ndata['feature']`` and ``g.ndata['label']``.
        args: namespace providing ``dataset``, ``seed``, ``train_ratio`` and ``epoch``.

    Returns:
        Tuple ``(best_test_f1, best_test_auc, best_test_acc1, best_test_acc0,
        best_test_gmean)`` measured at the best validation epoch.
    """
    # Logging setup.
    log_dir = "logs"
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"train_log_{args.dataset}_seed{args.seed}.txt")
    # basicConfig does nothing once the root logger has handlers, so calling
    # it again on later runs is harmless.
    logging.basicConfig(filename=log_file, level=logging.INFO,
                        format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    _ensure_console_logging()

    features = g.ndata['feature']
    labels = g.ndata['label']
    labels_np = labels.cpu().numpy()
    num_nodes = len(labels)

    # Drop every node labelled 2 before splitting.
    valid_indices = np.flatnonzero(labels_np != 2)
    logging.info(f"过滤标签为2的样本后，剩余样本数量: {len(valid_indices)}")

    # Read the dataset name from args rather than from a module-level global,
    # so the function also works when it is imported from another module.
    if args.dataset == 'amazon':
        valid_indices = valid_indices[valid_indices >= 3305]
    index = valid_indices.tolist()
    y_index = labels_np[index]

    # Stratified train / (validation + test) split, then validation / test.
    idx_train, idx_rest, _, y_rest = train_test_split(index, y_index, stratify=y_index,
                                                      train_size=args.train_ratio,
                                                      random_state=args.seed, shuffle=True)
    idx_valid, idx_test, _, _ = train_test_split(idx_rest, y_rest, stratify=y_rest,
                                                 test_size=0.67,
                                                 random_state=args.seed, shuffle=True)

    train_mask = _indices_to_mask(idx_train, num_nodes)
    val_mask = _indices_to_mask(idx_valid, num_nodes)
    test_mask = _indices_to_mask(idx_test, num_nodes)

    # Positive / negative counts in the training split.
    train_labels = labels_np[idx_train]
    pos_count = int((train_labels == 1).sum())
    neg_count = int((train_labels == 0).sum())

    logging.info('训练集正样本数: %d, 负样本数: %d', pos_count, neg_count)
    logging.info('train/dev/test samples: %d, %d, %d', train_mask.sum().item(), val_mask.sum().item(), test_mask.sum().item())

    # The masks are built only from indices whose label is not 2, so no further
    # label filtering is needed at evaluation time. These selections do not
    # change between epochs and are therefore computed once.
    val_sel = val_mask.numpy()
    test_sel = test_mask.numpy()
    val_labels = labels_np[val_sel]
    test_labels = labels_np[test_sel]

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_f1 = 0.0
    best_test_auc, best_test_f1, best_test_pre, best_test_rec = 0.0, 0.0, 0.0, 0.0
    best_test_acc0, best_test_acc1, best_test_gmean = 0.0, 0.0, 0.0
    best_test_epoch = 0

    # Negative/positive ratio, guarded against division by zero.
    # NOTE(review): this weight is only logged; it is not passed to the loss
    # below. Confirm whether the unweighted loss is intentional.
    weight = neg_count / max(pos_count, 1)
    logging.info('cross entropy weight: %f', weight)

    time_start = time.time()
    for e in range(args.epoch):
        model.train()
        logits, _ = model(features, 0)

        loss = F.nll_loss(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        model.eval()
        # NOTE(review): evaluation reuses the logits of the training-mode
        # forward pass above (computed before this epoch's optimizer step); no
        # second forward pass is made in eval mode.
        probs = torch.exp(logits).detach().cpu().numpy()  # log-probabilities -> probabilities
        pos_prob = probs[:, 1]

        # Tune the decision threshold on the validation set, then apply it to all nodes.
        f1, thres = get_best_f1(val_labels, probs[val_sel])
        preds = np.zeros_like(labels_np)
        preds[pos_prob > thres] = 1

        # Validation metrics.
        val_acc1, val_acc0 = _per_class_accuracy(val_labels, preds[val_sel])
        val_gmean = geometric_mean(val_acc0, val_acc1)

        # Test metrics.
        test_preds = preds[test_sel]
        if len(test_labels) > 0:
            trec = recall_score(test_labels, test_preds)
            tpre = precision_score(test_labels, test_preds)
            tmf1 = f1_score(test_labels, test_preds, average='macro')
            tauc = roc_auc_score(test_labels, pos_prob[test_sel])
        else:
            trec = tpre = tmf1 = tauc = 0
        test_acc1, test_acc0 = _per_class_accuracy(test_labels, test_preds)
        test_gmean = geometric_mean(test_acc0, test_acc1)

        # Keep the test metrics of the epoch with the best validation macro-F1.
        if f1 > best_val_f1:
            best_val_f1 = f1
            best_test_auc = tauc
            best_test_f1 = tmf1
            best_test_pre = tpre
            best_test_rec = trec
            best_test_acc0 = test_acc0
            best_test_acc1 = test_acc1
            best_test_gmean = test_gmean
            best_test_epoch = e

        # Validation metrics are logged every epoch.
        logging.info('Epoch %d, loss: %.4f, val mf1: %.4f, val ACC1: %.4f, val ACC0: %.4f, val Gmean: %.4f (best %.4f)',
                     e, loss.item(), f1, val_acc1, val_acc0, val_gmean, best_val_f1)

        # Test metrics are logged every 10th epoch.
        if e % 10 == 0:
            logging.info('Test: REC %.4f PRE %.4f MF1 %.4f AUC %.4f ACC1 %.4f ACC0 %.4f Gmean %.4f',
                        trec, tpre, tmf1, tauc, test_acc1, test_acc0, test_gmean)

    time_end = time.time()
    logging.info('time cost: %.2f s', time_end - time_start)
    logging.info('Best Epoch %d, Test: REC %.2f PRE %.2f MF1 %.2f AUC %.2f ACC1 %.2f ACC0 %.2f Gmean %.2f',
                best_test_epoch, best_test_rec*100, best_test_pre*100, best_test_f1*100, best_test_auc*100,
                best_test_acc1*100, best_test_acc0*100, best_test_gmean*100)

    return best_test_f1, best_test_auc, best_test_acc1, best_test_acc0, best_test_gmean


# threshold adjusting for best macro f1
def get_best_f1(labels, probs):
    """Search a fixed grid of thresholds for the one giving the highest macro-F1.

    Args:
        labels: ground-truth labels, as a tensor or a NumPy array.
        probs: two-column class probabilities (tensor or NumPy array); column 1
            is compared against each threshold.

    Returns:
        ``(best_f1, best_threshold)``. Thresholds are the 19 evenly spaced
        values from 0.05 to 0.95; ``(0, 0)`` is returned if no threshold gives
        a macro-F1 above 0.
    """
    # Work on CPU NumPy arrays regardless of the input container.
    y_true = labels.cpu().numpy() if isinstance(labels, torch.Tensor) else labels
    scores = probs.detach().cpu().numpy() if isinstance(probs, torch.Tensor) else probs
    pos_scores = scores[:, 1]

    top_f1, top_thres = 0, 0
    for cutoff in np.linspace(0.05, 0.95, 19):
        y_pred = np.zeros_like(y_true)
        y_pred[pos_scores > cutoff] = 1
        score = f1_score(y_true, y_pred, average='macro')
        # Strict comparison: on ties the earliest (lowest) threshold is kept.
        if score > top_f1:
            top_f1, top_thres = score, cutoff
    return top_f1, top_thres


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='BWGNN')
    parser.add_argument("--dataset", type=str, default="amazon",
                        help="Dataset for this model (yelp/amazon/tfinance/tsocial)")
    parser.add_argument("--train_ratio", type=float, default=0.4, help="Training ratio")
    parser.add_argument("--hid_dim", type=int, default=64, help="Hidden layer dimension")
    parser.add_argument("--order", type=int, default=2, help="Order C in Beta Wavelet")
    parser.add_argument("--homo", type=int, default=1, help="1 for BWGNN(Homo) and 0 for BWGNN(Hetero)")
    parser.add_argument("--epoch", type=int, default=100, help="The max number of epochs")
    parser.add_argument("--run", type=int, default=1, help="Running times")
    parser.add_argument("--seed", type=int, default=32, help="Random seed")

    args = parser.parse_args()
    print(args)

    # Seed the random number generators once, before the graph is loaded and
    # any model is constructed.
    set_seed(args.seed)

    # These stay module-level names: functions in this file may look up
    # `dataset_name` as a global.
    dataset_name = args.dataset
    homo = args.homo
    order = args.order
    h_feats = args.hid_dim
    graph = Dataset(dataset_name, homo).graph
    in_feats = graph.ndata['feature'].shape[1]
    num_classes = 2

    def build_model():
        """Construct a new model instance for one training run.

        Single-run and multi-run modes share this factory so both build the
        model with the same constructor arguments.
        NOTE(review): the homogeneous call follows the default single-run
        path, which passes the graph three times; the multi-run path used to
        pass it only once. Confirm against the BWGNN constructor signature.
        """
        if homo:
            return BWGNN(in_feats, h_feats, num_classes, graph, graph, graph, d=order)
        return BWGNN_Hetero(in_feats, h_feats, num_classes, graph, d=order)

    if args.run == 1:
        train(build_model(), graph, args)

    else:
        final_mf1s, final_aucs, final_acc1s, final_acc0s, final_gmeans = [], [], [], [], []
        for _ in range(args.run):
            mf1, auc, acc1, acc0, gmean = train(build_model(), graph, args)
            final_mf1s.append(mf1)
            final_aucs.append(auc)
            final_acc1s.append(acc1)
            final_acc0s.append(acc0)
            final_gmeans.append(gmean)

        final_mf1s = np.array(final_mf1s)
        final_aucs = np.array(final_aucs)
        final_acc1s = np.array(final_acc1s)
        final_acc0s = np.array(final_acc0s)
        final_gmeans = np.array(final_gmeans)

        # Mean and standard deviation over all runs, reported in percent.
        logging.info('MF1-mean: %.2f, MF1-std: %.2f, AUC-mean: %.2f, AUC-std: %.2f',
                   100 * np.mean(final_mf1s), 100 * np.std(final_mf1s),
                   100 * np.mean(final_aucs), 100 * np.std(final_aucs))
        logging.info('ACC1-mean: %.2f, ACC1-std: %.2f, ACC0-mean: %.2f, ACC0-std: %.2f',
                   100 * np.mean(final_acc1s), 100 * np.std(final_acc1s),
                   100 * np.mean(final_acc0s), 100 * np.std(final_acc0s))
        logging.info('Gmean-mean: %.2f, Gmean-std: %.2f',
                   100 * np.mean(final_gmeans), 100 * np.std(final_gmeans))
